from __future__ import absolute_import, division, print_function
import sys
# importing time module
import time
from datetime import datetime
import torch
import random
import torch.optim as optim
import torch.multiprocessing as mp
import os

sys.path.append('../')
from tqdm import tqdm
import lib

import lib.common_ptan as ptan

import numpy as np
import gym
import moenvs
from lib.utilities.MORL_utils import MOOfflineEnv
from lib.utilities.common_utils import make_config

from collections import namedtuple, deque
import copy
import wandb
import argparse
import gc

def make_seeded_env_factories(build_env, num_envs):
    """Return ``num_envs`` zero-argument environment factories with distinct seeds.

    Args:
        build_env: Callable taking a single positional ``seed`` and returning
            an environment.
        num_envs: Number of factories to create; the seeds used are
            ``0 .. num_envs - 1``.

    Returns:
        A list of callables; calling the k-th one invokes ``build_env(k)``.
    """
    # Bind the seed as a default argument. A plain closure over the loop
    # variable is late-binding, so every factory would otherwise be called
    # with the final seed value.
    return [lambda seed=seed: build_env(seed) for seed in range(num_envs)]


if __name__ == "__main__":
    # Script flow: parse CLI options, merge them into the per-environment
    # hyper-parameters, build the networks for the selected algorithm, train
    # on the offline dataset (unless --test_only), then run the adaptation
    # evaluation.

    parser = argparse.ArgumentParser()
    parser.add_argument("--algo", default='Diffusion-QL', help="TD3+BC, Diffusion-QL, IQL, PEDA, BC, Prompt-MODT", type=str)
    parser.add_argument("--env", default="Hopper")
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--use_wandb", default=1, type=int)
    parser.add_argument("--test_only", default=0, type=int)
    # NOTE(review): this script itself never reads --load_model; checkpoint
    # loading below is gated by --test_only only.
    parser.add_argument("--load_model", default=0, type=int)
    parser.add_argument("--dataset", default='d4morl', help='d4morl, safe, cmo', type=str)
    parser.add_argument("--dataset_type", default=None)
    parser.add_argument("--policy_freq", default=2, type=int)
    parser.add_argument("--gpu", default=0, type=int)
    parser.add_argument("--normalize_states", default=1, type=int)
    parser.add_argument("--pref_perturb_theta", default=0.0, type=float)
    parser.add_argument("--pref_gen_way", default=None, type=str, help='L1_return, highest_return, None')
    parser.add_argument("--record_label", default="", type=str)

    parser.add_argument("--time_steps", default=None, type=int)
    parser.add_argument("--eval_freq", default=100000, type=int)
    parser.add_argument("--gamma", default=0.995, type=float)
    parser.add_argument("--weight_bc_loss", default=None, type=float)
    parser.add_argument("--w_step_size_final_eval", default=0.01, type=float)

    parser.add_argument("--num_eval_env", default=5, type=int)
    parser.add_argument("--lr_decay", default=1, type=int)
    parser.add_argument("--lr_critic", default=3e-4, type=float)
    parser.add_argument("--lr_actor", default=3e-4, type=float)
    parser.add_argument("--policy_noise", default=0.2, type=float)
    parser.add_argument("--batch_size", default=256, type=int)
    parser.add_argument("--weight_num", default=1, type=int)
    parser.add_argument("--angle_loss_coeff", default=0.0, type=float)

    parser.add_argument("--fixed_pref1", default=None, type=float)
    parser.add_argument("--fixed_pref1_traj_num", default=None, type=int)
    parser.add_argument("--pre_sample_traj_num", default=None, type=int)
    parser.add_argument("--dataset_preprocess", default='exclude_test_data', type=str, help="exclude_test_data, None")
    parser.add_argument("--reward_normalize", default=0, type=int)

    parser.add_argument("--adpt_steps", default=1000, type=int)
    parser.add_argument("--adpt_lr", default=0.05, type=float)
    parser.add_argument("--adpt_batch_demo", default=128, type=int)
    parser.add_argument("--adpt_batch_pref", default=64, type=int)
    parser.add_argument("--adpt_td_weight", default=0.01, type=float)
    parser.add_argument("--adpt_prior_weight", default=1.0, type=float)
    parser.add_argument("--adpt_entropy_weight", default=1.0, type=float)
    parser.add_argument("--adpt_td_type", default='qf', help='qf or vf', type=str)
    parser.add_argument("--adpt_fix_dim1", default=0, type=int)

    # NOTE(review): parsed as float although the default is an integer step
    # count -- confirm that the consumers accept a float when it is set on
    # the command line.
    parser.add_argument("--finetune_step", default=1000, type=float)
    parser.add_argument("--finetune_lr", default=0.01, type=float)

    # IQL parameters
    parser.add_argument("--iql_beta", default=1.0/3, type=float)
    parser.add_argument("--iql_quantile", default=0.9, type=float)
    parser.add_argument("--iql_clip_score", default=100, type=float)
    parser.add_argument("--iql_clip_min_score", default=None, type=float)
    # NOTE(review): step count parsed as float -- see --finetune_step.
    parser.add_argument("--iql_warmup_step", default=0, type=float)

    # Diffusion parameters
    parser.add_argument("--diffusion_n_timesteps", default=5, type=int)

    # For safe RL
    parser.add_argument("--cvar_alpha", default=0.8, type=float)

    start_time = time.time()
    input_args = parser.parse_args()

    # Fail fast on an unknown algorithm: none of the branches below would
    # define `actor`, which previously surfaced later as a NameError.
    supported_algos = ('TD3+BC', 'Diffusion-QL', 'IQL', 'PEDA', 'BC', 'Prompt-MODT')
    if input_args.algo not in supported_algos:
        parser.error(f"unsupported --algo {input_args.algo!r}; choose one of: {', '.join(supported_algos)}")

    USE_WANDB = input_args.use_wandb
    SEED = input_args.seed
    env_name = input_args.env
    dataset_type = input_args.dataset_type

    # Start from the per-environment hyper-parameters and overlay the CLI
    # values. `time_steps` and `weight_bc_loss` keep their per-environment
    # value when they are not given on the command line.
    args = lib.utilities.settings.HYPERPARAMS[env_name]()
    for arg in vars(input_args):
        if arg in ('time_steps', 'weight_bc_loss') and getattr(input_args, arg) is None:
            continue
        setattr(args, arg, getattr(input_args, arg))
    args.name = f"{env_name}_{input_args.dataset}_{dataset_type}_{args.num_objective}obj"
    writer = lib.utilities.common_utils.WandbWriter(USE_WANDB, "morl_adaptation", args, SEED)

    torch.manual_seed(SEED)
    random.seed(SEED)
    np.random.seed(SEED)
    os.environ['CUDA_VISIBLE_DEVICES'] = f"{args.gpu}"
    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device

    # Set up the environments. The number of evaluation environments is tied
    # to the number of evaluation episodes (this overrides --num_eval_env).
    args.num_eval_env = args.eval_episodes
    env_kwargs = dict(safe_obj_list=args.safe_obj_list, dataset_class=args.dataset, num_objective=args.num_objective)
    env_main = MOOfflineEnv(env_name, **env_kwargs)
    # One factory per evaluation environment, each bound to its own seed.
    test_env_main = make_seeded_env_factories(
        lambda seed: MOOfflineEnv(env_name, seed=seed, **env_kwargs),
        args.num_eval_env,
    )

    # Initialize environment related arguments
    args.obs_shape = env_main.observation_space.shape[0]
    args.action_shape = env_main.action_space.shape[0]
    args.reward_size = len(env_main.reward_space)
    args.max_action = env_main.action_space.high
    args.max_episode_len = env_main._max_episode_steps if args.max_episode_len is None else args.max_episode_len
    args.writer = writer

    # Initialize the networks. The critic and value function are always
    # built; the actor class depends on the selected algorithm.
    behavior_prior = None
    critic = lib.models.networks.Critic(args).to(device)
    value_function = lib.models.networks.ValueFunction(args).to(device)
    if args.algo == 'TD3+BC' or args.algo == 'BC':
        actor = lib.models.networks.Actor(args, is_gaussian=False).to(device)
    elif args.algo == 'IQL':
        actor = lib.models.networks.Actor(args, is_gaussian=True).to(device)
    elif args.algo == 'Diffusion-QL':
        from lib.diffusion.agents.diffusion import Diffusion
        from lib.diffusion.agents.model import MLP
        mlp_model = MLP(args.obs_shape, args.action_shape, args.reward_size, args.device).to(device)
        actor = Diffusion(mlp_model, args, n_timesteps=args.diffusion_n_timesteps).to(device)
    elif args.algo == 'PEDA':
        from gym.spaces.box import Box
        from torch import nn
        from lib.PEDA.rvs.src.rvs.policies import RvS as Model
        # The model's state dimension is observation size plus reward size.
        state_dim = args.obs_shape + args.reward_size
        observation_space_place_holder = Box(low=np.zeros(state_dim), high=np.ones(state_dim),)
        action_space = Box(low=env_main.action_space.low, high=env_main.action_space.high, dtype=env_main.action_space.dtype)
        actor = Model(
            observation_space=observation_space_place_holder,
            action_space=action_space, state_dim=state_dim, act_dim=args.action_shape, pref_dim=args.reward_size, rtg_dim=args.reward_size,
            hidden_size=512, depth=3, learning_rate=1e-4, batch_size=64, activation_fn=nn.ReLU, dropout_p=0.1, unconditional_policy=False, reward_conditioning=True, env_name="",
        ).to(device=args.device)
        actor.state_dim = state_dim
        actor.act_dim = args.action_shape
        actor.pref_dim = args.reward_size
        actor.rtg_dim = args.reward_size
    elif args.algo == 'Prompt-MODT':
        from lib.PEDA.modt.models.decision_transformer import DecisionTransformer as Model
        state_dim = args.obs_shape
        actor = Model(
            state_dim=state_dim,
            act_dim=args.action_shape,
            pref_dim=args.reward_size,
            rtg_dim=args.reward_size,
            hidden_size=512,
            max_length=20,
            eval_context_length=5,
            max_ep_len=args.max_episode_len,
            act_scale=torch.from_numpy(np.array(env_main.action_space.high)),
            use_pref=False,
            concat_state_pref=False,
            concat_rtg_pref=False,
            concat_act_pref=False,
            n_layer=3,
            n_head=1,
            n_inner=4*512,
            activation_function='relu',
            n_positions=1024,
            resid_pdrop=0.1,
            attn_pdrop=0.1
        ).to(device=device)

    # Name under which the neural network models are stored
    args.name_model = args.name

    # (module, file stem) pairs shared by checkpoint loading and saving.
    checkpoint_modules = (
        (actor, 'actor'),
        (critic, 'critic'),
        (value_function, 'value_function'),
    )

    def save_checkpoint():
        """Save the actor, critic and value function via the project's save_model helper."""
        for module, ext in checkpoint_modules:
            lib.utilities.common_utils.save_model(module, args, name=args.name, ext=ext)

    # Load previously trained model
    if args.test_only:
        load_path = "Exps/{}/{}/".format(args.name, args.seed)
        for module, ext in checkpoint_modules:
            # map_location lets a checkpoint saved on GPU be restored on
            # whichever device this run uses (including CPU-only machines).
            state_dict = torch.load("{}{}.pkl".format(load_path, ext), map_location=device)
            module.load_state_dict(state_dict)
        print('load model!!!')

    # Initialize preference spaces
    if args.dataset != 'cmo':
        w_batch_test = lib.utilities.MORL_utils.generate_w_batch_test(args, step_size = args.w_step_size_final_eval, reward_size=len(env_main.reward_space))
    else:
        # For 'cmo': scale each 2-d weight pair by (1 - cw) and append cw as
        # a third component, for every cw produced by the arange below.
        w_batch_test = []
        w_batch = np.array([[x,1.0-x] for x in [0.5,0.6,0.7,0.8,0.9,1.0]])
        for cw in np.arange(0, 1.1, 0.1):
            w_batch_test.extend([ np.concatenate([w*(1-cw), np.array([cw])]) for w in w_batch])
        w_batch_test = np.array(w_batch_test)

    # Initialize Experience Source and Replay Buffer
    replay_buffer_main = ptan.experience.ReplayBuffer(args)
    train_dataset = env_main.get_dataset(dataset_type)
    test_datasets = env_main.get_test_dataset()

    # Must preprocess train dataset before preprocessing test dataset so that the preprocessing information of train dataset can be used for test dataset
    preprocessor = ptan.experience.Preprocessor(args)
    train_dataset = preprocessor(train_dataset, env_main, dataset_type=dataset_type, is_test_dataset=False)
    for i in range(len(test_datasets)):
        test_datasets[i]['demo'] = preprocessor(test_datasets[i]['demo'], env_main, dataset_type=dataset_type, is_test_dataset=True)
    replay_buffer_main.load_from_dataset(env_main, train_dataset)
    if args.algo == 'PEDA' or args.algo == 'Prompt-MODT':
        traj_buffer_main = ptan.experience.TrajReplayBuffer(args) if args.algo == 'PEDA' else ptan.experience.PromptTrajReplayBuffer(args, w_batch_test, args.adpt_batch_demo)
        traj_buffer_main.load_from_dataset(env_main, train_dataset)
    mean, std = preprocessor.mean, preprocessor.std

    def state_normalizer(state):
        """Standardize ``state`` with the preprocessor's mean/std and return a float32 tensor."""
        return torch.tensor(np.array((state-mean)/std, dtype=np.float32))

    if args.algo == 'PEDA':
        agent_main = ptan.agent.MOOF_PEDA_AGENT(actor, args, traj_buffer_main, env_main, mean, std, preprocessor.reward_scale_weight)
    elif args.algo == 'Prompt-MODT':
        agent_main = ptan.agent.MOOF_PROMPT_MODT_AGENT(actor, args, traj_buffer_main, env_main, mean, std, preprocessor.reward_scale_weight)
    else:
        agent_main = ptan.agent.MOOF_QL_AGENT(actor, critic, value_function, device, behavior_prior, args, state_normalizer)

    time_step = 0
    if not args.test_only:
        # Main Loop
        eval_cnt = 1
        for ts in tqdm(range(0, args.time_steps), mininterval=10):  # iterate through the fixed number of timesteps
            # Learn from the minibatch
            if args.algo == 'PEDA' or args.algo == 'Prompt-MODT':
                agent_main.train_step(traj_buffer_main, writer)
            elif args.algo == 'IQL':
                agent_main.train_iql(replay_buffer_main, writer)
            elif args.algo == 'BC':
                agent_main.train_bc_policy(replay_buffer_main, writer)
            else:
                agent_main.train_regularized_policy(replay_buffer_main, writer)

            time_step = ts
            # Periodic checkpoint. No evaluation runs at this point, so the
            # "Eval time" printed below only covers saving the models.
            if ts > args.eval_freq*eval_cnt:
                test_time_start = time.time()
                eval_cnt += 1
                print(f"Time steps: {time_step}, Episode Count of Each Process: {time_step}")
                save_checkpoint()
                print(f"Eval time: {time.time()-test_time_start}")

        print(f"Total Number of Time Steps: {time_step}")
        save_checkpoint()

    # Pick the adaptation strategy matching the algorithm.
    if args.algo == 'BC':
        adaptation_method = lib.utilities.MORL_utils.FinetuneAdaptation(replay_buffer_main, train_dataset, agent_main, args)
    elif args.algo == 'Prompt-MODT':
        adaptation_method = lib.utilities.MORL_utils.PromptMODTAdaptation(replay_buffer_main, train_dataset, agent_main, args)
    else:
        adaptation_method = lib.utilities.MORL_utils.PrefDistAdaptation(replay_buffer_main, train_dataset, agent_main, args)

    # Drop this script's references to the training data before evaluation.
    del replay_buffer_main, train_dataset
    if args.algo == 'PEDA' or args.algo == 'Prompt-MODT':
        del traj_buffer_main
    gc.collect()
    hypervolume, sparsity, objs = lib.utilities.MORL_utils.eval_agent_adaptation(
            adaptation_method, test_env_main, env_main, agent_main, args, time_step, test_datasets, eval_episodes=args.eval_episodes)

    if args.algo != 'BC' and args.algo != 'Prompt-MODT':
        hypervolume, sparsity, objs = lib.utilities.MORL_utils.eval_agent(test_env_main, env_main, agent_main, test_datasets, w_batch_test, args, time_step, eval_episodes=args.eval_episodes)

    print("Time Consumed")
    print("%0.2f minutes" % ((time.time() - start_time)/60))
    if USE_WANDB:
        wandb.finish()
    

    
